{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Amazon SageMaker Neural Topic Model now supports auxiliary vocabulary channel, new topic evaluation metrics, and training subsampling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Amazon SageMaker Neural Topic Model (NTM) is an unsupervised learning algorithm that learns the topic distributions of large document collections (corpora). With SageMaker NTM, you can build machine learning solutions for use cases such as document classification, information retrieval, and content recommendation. SageMaker provides a rich set of model training configuration options, such as network architecture and automatic early stopping, as well as hyperparameters for trading off among a multitude of metrics such as document modeling accuracy, human interpretability, and granularity of the learned topics. See [Introduction to the Amazon SageMaker Neural Topic Model](https://aws.amazon.com/blogs/machine-learning/introduction-to-the-amazon-sagemaker-neural-topic-model/) if you are not already familiar with SageMaker and SageMaker NTM.\n",
    "\n",
    "If you are new to machine learning, or want to free up time to focus on other tasks, the fully automated Amazon Comprehend topic modeling API is the best option. If you are a data science specialist looking for finer control over the various layers of building and tuning your own topic model, Amazon SageMaker NTM might work better for you. For example, let’s say you are building a document topic tagging application that needs a customized vocabulary, and you need the ability to adjust algorithm hyperparameters, such as the number of layers of the neural network, so that you can train a topic model that meets your targets for coherence and uniqueness scores. In this case, Amazon SageMaker NTM would be the appropriate tool to use.\n",
    "\n",
    "In this blog post, we introduce three new features of SageMaker NTM that help improve productivity, enhance topic coherence evaluation capability, and speed up model training.\n",
    "\n",
    "- Auxiliary vocabulary channel\n",
    "- Word Embedding Topic Coherence (WETC) and Topic Uniqueness (TU) metrics\n",
    "- Subsampling data during training\n",
    "\n",
    "In addition to these new features, by optimizing sparse operations and the parameter server, we have improved the speed of the algorithm by 2x for training and 4x for evaluation on a single GPU. The speedup is even more significant for multi-GPU training."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Auxiliary vocabulary channel\n",
    "When an NTM training job runs, it outputs the training status and evaluation metrics to the CloudWatch logs. Among the outputs are lists of top words detected for each of the learned topics. Prior to the availability of auxiliary vocabulary channel support, the top words were represented as integers, and customers needed to map the integers to an external custom vocabulary lookup table in order to know what the actual words were. With auxiliary vocabulary channel support, users can now add a vocabulary file as an additional data input channel, and SageMaker NTM will output the actual words in a topic instead of integers. This feature eliminates the manual effort needed to map integers to the actual vocabulary. Below is a sample of what a custom vocabulary text file looks like. The text file simply contains a list of words, one word per row, in the order corresponding to the integer IDs provided in the data.\n",
    "```\n",
    "absent\n",
    "absentee\n",
    "absolute\n",
    "absolutely\n",
    "```\n",
    "\n",
    "To include an auxiliary vocabulary for a training job, you should name the vocabulary file `vocab.txt` and place it in the auxiliary channel. See the code example below for how to add the auxiliary vocabulary file. In this release we only support the UTF-8 encoding for the vocabulary file. \n"
   ]
  },
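  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the correspondence between integer IDs and rows of the vocabulary file concrete, here is a minimal sketch of the manual lookup that the auxiliary channel makes unnecessary. The temporary file location, the `top_word_ids` list, and the assumption that IDs are zero-based row indices (as they are for the vocabulary built later in this notebook) are illustrative choices, not part of the NTM API.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "# Hypothetical four-word vocabulary, one word per line, UTF-8 encoded\n",
    "words = ['absent', 'absentee', 'absolute', 'absolutely']\n",
    "vocab_path = os.path.join(tempfile.mkdtemp(), 'vocab.txt')\n",
    "with open(vocab_path, 'w', encoding='utf-8') as f:\n",
    "    for word in words:\n",
    "        print(word, file=f)\n",
    "\n",
    "# Row i of the file holds the word for integer ID i (assumed zero-based)\n",
    "with open(vocab_path, encoding='utf-8') as f:\n",
    "    id_to_word = [line.strip() for line in f]\n",
    "\n",
    "top_word_ids = [2, 0, 3]  # made-up IDs standing in for a topic's top words\n",
    "print([id_to_word[i] for i in top_word_ids])\n",
    "```\n",
    "\n",
    "With `vocab.txt` supplied in the auxiliary channel, this lookup happens inside the training job and the log shows the words directly."
   ]
  },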
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Word Embedding Topic Coherence metrics\n",
    "To evaluate the performance of a trained SageMaker NTM model, customers can examine the perplexity metric emitted by the training job. Sometimes, however, customers also want to evaluate the topic coherence of a model, which measures the closeness of the words in a topic. A good topic should have semantically similar words in it. Traditional methods like Normalized Point-wise Mutual Information (NPMI), while widely accepted, require a large external corpus. The new WETC metric measures the similarity of words in a topic by using a pre-trained word embedding, [Glove-6B-400K-50d](https://nlp.stanford.edu/projects/glove/).\n",
    "\n",
    "Intuitively, each word in the vocabulary is given a vector representation (embedding). We compute the WETC of a topic by averaging the pair-wise cosine similarities between the vectors corresponding to the top words of the topic. Finally, we average the WETC for all the topics to obtain a single score for the model. \n",
    "\n",
    "Our tests have shown that WETC correlates very well with NPMI, making it an effective surrogate. For details about the pair-wise WETC computation and its correlation to NPMI, please refer to our paper [Coherence-Aware Neural Topic Modeling, Ding et al. 2018 (Accepted for EMNLP 2018)](https://arxiv.org/pdf/1809.02687.pdf).\n",
    "\n",
    "WETC ranges between 0 and 1; higher is better. Typical values are in the range of 0.2 to 0.8. The WETC metric is evaluated whenever the vocabulary file is provided. The average WETC score over the topics is displayed in the log above the top words of all topics. The WETC metric for each topic is also displayed along with the top words of each topic. Please refer to the screenshot below for an example.\n",
    "\n",
    "> Note: if many of the words in the supplied vocabulary cannot be found in the pre-trained word embedding, the WETC score can be misleading. We therefore emit a warning message that tells the user exactly how many words in the vocabulary do not have an embedding:\n",
    "\n",
    "```\n",
    "[09/07/2018 14:18:57 WARNING 140296605947712] 69 out of 16648 in vocabulary do not have embeddings! Default vector used for unknown embedding!\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Log with WETC metrics](WETC_screenshot.png)"
   ]
  },
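  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pair-wise averaging behind WETC can be sketched in a few lines of NumPy. This is a minimal illustration, not NTM's internal implementation: the `toy_embeddings` table and its 3-dimensional vectors are made up for the example (the real metric uses the pre-trained GloVe vectors), and `cosine` and `topic_wetc` are hypothetical helper names.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "import numpy as np\n",
    "\n",
    "def cosine(u, v):\n",
    "    # Cosine similarity between two 1-D vectors\n",
    "    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))\n",
    "\n",
    "def topic_wetc(top_words, embeddings):\n",
    "    # Average cosine similarity over all unordered pairs of top words\n",
    "    pairs = list(itertools.combinations(top_words, 2))\n",
    "    return sum(cosine(embeddings[a], embeddings[b]) for a, b in pairs) / len(pairs)\n",
    "\n",
    "# Hypothetical 3-d embeddings, made up for this example\n",
    "toy_embeddings = {\n",
    "    'lobster': np.array([0.9, 0.1, 0.0]),\n",
    "    'crab': np.array([0.8, 0.2, 0.1]),\n",
    "    'claw': np.array([0.7, 0.3, 0.0]),\n",
    "    'guitar': np.array([0.0, 0.1, 0.9]),\n",
    "}\n",
    "\n",
    "coherent = topic_wetc(['lobster', 'crab', 'claw'], toy_embeddings)\n",
    "mixed = topic_wetc(['lobster', 'crab', 'guitar'], toy_embeddings)\n",
    "print('coherent topic WETC: {:.3f}'.format(coherent))\n",
    "print('mixed topic WETC:    {:.3f}'.format(mixed))\n",
    "```\n",
    "\n",
    "The topic whose top words point in similar directions scores higher, and averaging `topic_wetc` over all topics gives a single model-level score."
   ]
  },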
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Topic Uniqueness metric\n",
    "\n",
    "A good topic modeling algorithm should generate topics that are unique to avoid topic duplication. Customers who want to understand the topic uniqueness of a trained Amazon SageMaker NTM model to evaluate its quality can now use the new TU metric. \n",
    "To understand how TU works, suppose there are K topics and we extract the top n words for each topic. The TU for topic k is then defined as\n",
    "![TU definition](TU_definition.png)\n",
    "\n",
    "The range of the TU value is between 1/K and 1, where K is the number of topics. A higher TU value represents higher topic uniqueness for the topics detected.\n",
    "\n",
    "The TU score is displayed regardless of the existence of a vocabulary file. Similar to the WETC, the average TU score over the topics is displayed in the log above the top words of all topics; the TU score for each topic is also displayed along with the top words of each topic. Please refer to the screenshot below for an example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Log with TU metrics](TU_screenshot.png)"
   ]
  },
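  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of TU, the sketch below assumes the definition takes the form $TU(k) = \\frac{1}{n}\\sum_{i=1}^{n} \\frac{1}{cnt(i, k)}$, where $cnt(i, k)$ counts how many topics contain the $i$-th top word of topic $k$ among their top words. That form is consistent with the stated range of 1/K to 1, but treat it as an assumption; `topic_uniqueness` and the `toy_topics` lists are hypothetical names and data.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "def topic_uniqueness(topics):\n",
    "    # topics: list of K lists, each holding the top-n words of one topic\n",
    "    # counts[w] = number of topics whose top words include w\n",
    "    counts = Counter(w for words in topics for w in set(words))\n",
    "    return [sum(1.0 / counts[w] for w in words) / len(words) for words in topics]\n",
    "\n",
    "toy_topics = [\n",
    "    ['lobster', 'crab', 'claw'],\n",
    "    ['lobster', 'crab', 'sea'],\n",
    "    ['guitar', 'album', 'song'],\n",
    "]\n",
    "\n",
    "tu_per_topic = topic_uniqueness(toy_topics)\n",
    "average_tu = sum(tu_per_topic) / len(tu_per_topic)\n",
    "print('TU per topic:', ['{:.3f}'.format(tu) for tu in tu_per_topic])\n",
    "print('average TU: {:.3f}'.format(average_tu))\n",
    "```\n",
    "\n",
    "The first two topics share two of their three words and are penalized, while the third topic shares none and reaches the maximum of 1."
   ]
  },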
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we introduce a new hyperparameter for subsampling the data during training.\n",
    "\n",
    "## Subsampling data during training\n",
    "\n",
    "In typical online training, the entire training dataset is fed into the training algorithm in each epoch. When the corpus is large, this leads to long training times. With effective subsampling of the training dataset, we can achieve faster model convergence while maintaining model performance. The new subsampling feature of SageMaker NTM lets customers specify the fraction of the training data used in each epoch through a new hyperparameter, `sub_sample`. For example, specifying 0.8 for `sub_sample` directs SageMaker NTM to use a random 80% of the training data in each epoch. As a result, the algorithm stochastically covers different subsets of the data during different epochs. You can configure this value either in the SageMaker console or directly in your training code. See the sample code below for how to set this value for training.\n",
    "\n",
    "```\n",
    "ntm.set_hyperparameters(num_topics=num_topics, feature_dim=vocab_size, mini_batch_size=128, \n",
    "                        epochs=100, sub_sample=0.7)\n",
    "```\n",
    "At the end of this notebook we will demonstrate that using subsampling can reduce the overall training time for large datasets and potentially achieve higher topic uniqueness and coherence."
   ]
  },
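  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what covering different subsets in different epochs means, here is a conceptual sketch of per-epoch random subsampling. How NTM draws its subsets internally is not described here, so the sampling without replacement, the fixed seed, and the `epoch_subsets` helper are all assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def epoch_subsets(num_docs, sub_sample, epochs, seed=0):\n",
    "    # Yield one random subset of document indices per epoch (no replacement)\n",
    "    rng = np.random.RandomState(seed)\n",
    "    subset_size = int(round(sub_sample * num_docs))\n",
    "    for _ in range(epochs):\n",
    "        yield rng.choice(num_docs, size=subset_size, replace=False)\n",
    "\n",
    "seen = set()\n",
    "subset_sizes = []\n",
    "for epoch, idx in enumerate(epoch_subsets(num_docs=1000, sub_sample=0.8, epochs=3)):\n",
    "    seen.update(idx.tolist())\n",
    "    subset_sizes.append(len(idx))\n",
    "    print('epoch {}: {} docs used, {} distinct docs seen so far'.format(\n",
    "        epoch, len(idx), len(seen)))\n",
    "```\n",
    "\n",
    "Each epoch touches only 80% of the documents, yet the set of documents seen at least once keeps growing across epochs."
   ]
  },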
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "To illustrate the new features, we now walk through an example with the public WikiText language modeling dataset.\n",
    "\n",
    "## Data Preparation\n",
    "\n",
    "The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike License. The dataset can be downloaded from [here](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset). We will first use the wikitext-2 dataset. \n",
    "\n",
    "> **Acknowledgements:** Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. 2016. Pointer Sentinel Mixture Models "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fetching Data Set\n",
    "\n",
    "First, let's define the folder that holds the data and remove any content left in it from previous experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "        \n",
    "def check_create_dir(dir):\n",
    "    if os.path.exists(dir):  # cleanup existing data folder\n",
    "        shutil.rmtree(dir)\n",
    "    os.mkdir(dir)\n",
    "\n",
    "dataset = 'wikitext-2'\n",
    "current_dir = os.getcwd()\n",
    "data_dir = os.path.join(current_dir,dataset)\n",
    "check_create_dir(data_dir)\n",
    "os.chdir(data_dir)\n",
    "print('Current directory: ', os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can download and unzip the data. *Please review the following Acknowledgements, Copyright Information, and Availability notice before downloading the data.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# **Acknowledgements, Copyright Information, and Availability**\n",
    "# This dataset is available under the Creative Commons Attribution-ShareAlike License\n",
    "# Source: https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset\n",
    "\n",
    "!curl -O https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip\n",
    "!unzip wikitext-2-v1.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sample of `wiki.valid.tokens` is shown below. The dataset contains markdown text with all documents (articles) concatenated. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "\n",
    " = Homarus gammarus =\n",
    "\n",
    " Homarus gammarus , known as the European lobster or common lobster , is a species of <unk> lobster from the eastern Atlantic Ocean , Mediterranean Sea and parts of the Black Sea . It is closely related to the American lobster , H. americanus . It may grow to a length of 60 cm ( 24 in ) and a mass of 6 kilograms ( 13 lb ) , and bears a conspicuous pair of claws . In life , the lobsters are blue , only becoming \" lobster red \" on cooking . Mating occurs in the summer , producing eggs which are carried by the females for up to a year before hatching into <unk> larvae . Homarus gammarus is a highly esteemed food , and is widely caught using lobster pots , mostly around the British Isles .\n",
    "\n",
    " = = Description = =\n",
    "\n",
    " Homarus gammarus is a large <unk> , with a body length up to 60 centimetres ( 24 in ) and weighing up to 5 – 6 kilograms ( 11 – 13 lb ) , although the lobsters caught in lobster pots are usually 23 – 38 cm ( 9 – 15 in ) long and weigh 0 @.@ 7 – 2 @.@ 2 kg ( 1 @.@ 5 – 4 @.@ 9 lb ) . Like other crustaceans , lobsters have a hard <unk> which they must shed in order to grow , in a process called <unk> ( <unk> ) . This may occur several times a year for young lobsters , but decreases to once every 1 – 2 years for larger animals .\n",
    " ```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing\n",
    "We first need to parse the input files into separate documents. We can identify each document by its title, which is a level-1 heading. To avoid false detection of document titles, we additionally check that the line containing the title is surrounded by blank lines. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_document_start(line):\n",
    "    # A document title is a level-1 heading such as '= Homarus gammarus ='.\n",
    "    # Sub-headings start with '= =', so their third character is '='.\n",
    "    # Compare with == / != rather than 'is', which tests object identity.\n",
    "    if len(line) < 4:\n",
    "        return False\n",
    "    return line[0] == '=' and line[-1] == '=' and line[2] != '='\n",
    "\n",
    "\n",
    "def token_list_per_doc(input_dir, token_file):\n",
    "    lines_list = []\n",
    "    line_prev = ''\n",
    "    prev_line_start_doc = False\n",
    "    with open(os.path.join(input_dir, token_file), 'r', encoding='utf-8') as f:\n",
    "        for l in f:\n",
    "            line = l.strip()\n",
    "            if prev_line_start_doc and line:\n",
    "                # the previous line should not have been start of a document!\n",
    "                lines_list.pop()\n",
    "                lines_list[-1] = lines_list[-1] + ' ' + line_prev\n",
    "\n",
    "            if line:\n",
    "                if is_document_start(line) and not line_prev:\n",
    "                    lines_list.append(line)\n",
    "                    prev_line_start_doc = True\n",
    "                else:\n",
    "                    lines_list[-1] = lines_list[-1] + ' ' + line\n",
    "                    prev_line_start_doc = False\n",
    "            else:\n",
    "                prev_line_start_doc = False\n",
    "            line_prev = line\n",
    "\n",
    "    print(\"{} documents parsed!\".format(len(lines_list)))\n",
    "    return lines_list\n",
    "\n",
    "input_dir = os.path.join(data_dir, dataset)\n",
    "train_file = 'wiki.train.tokens'\n",
    "val_file = 'wiki.valid.tokens'\n",
    "test_file = 'wiki.test.tokens'\n",
    "train_doc_list = token_list_per_doc(input_dir, train_file)\n",
    "val_doc_list = token_list_per_doc(input_dir, val_file)\n",
    "test_doc_list = token_list_per_doc(input_dir, test_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the following cell, we use a lemmatizer from `nltk`. In the list comprehension, we implement a simple rule: only consider words that are at least 2 characters long, start with a letter, and match the `token_pattern`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install nltk\n",
    "import nltk\n",
    "# nltk.download('punkt')\n",
    "nltk.download('wordnet')\n",
    "from nltk.stem import WordNetLemmatizer \n",
    "import re\n",
    "token_pattern = re.compile(r\"(?u)\\b\\w\\w+\\b\")\n",
    "class LemmaTokenizer(object):\n",
    "    def __init__(self):\n",
    "        self.wnl = WordNetLemmatizer()\n",
    "    def __call__(self, doc):\n",
    "        return [self.wnl.lemmatize(t) for t in doc.split() if len(t) >= 2 and re.match(\"[a-z].*\",t) \n",
    "                and re.match(token_pattern, t)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We perform lemmatizing and counting next."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "\n",
    "print('Lemmatizing and counting, this may take a few minutes...')\n",
    "start_time = time.time()\n",
    "vectorizer = CountVectorizer(input='content', analyzer='word', stop_words='english',\n",
    "                             tokenizer=LemmaTokenizer(), max_df=0.9, min_df=3)\n",
    "\n",
    "train_vectors = vectorizer.fit_transform(train_doc_list)\n",
    "val_vectors = vectorizer.transform(val_doc_list)\n",
    "test_vectors = vectorizer.transform(test_doc_list)\n",
    "\n",
    "# get_feature_names() was removed in scikit-learn 1.2; use get_feature_names_out() when it exists\n",
    "if hasattr(vectorizer, 'get_feature_names_out'):\n",
    "    vocab_list = list(vectorizer.get_feature_names_out())\n",
    "else:\n",
    "    vocab_list = vectorizer.get_feature_names()\n",
    "vocab_size = len(vocab_list)\n",
    "print('vocab size:', vocab_size)\n",
    "print('Done. Time elapsed: {:.2f}s'.format(time.time() - start_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because all the parameters (weights and biases) in the NTM model are of type `np.float32`, the input data also needs to be `np.float32`. It is better to do this type-casting up front rather than repeatedly casting during mini-batch training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.sparse as sparse\n",
    "\n",
    "def shuffle_and_dtype(vectors):\n",
    "    idx = np.arange(vectors.shape[0])\n",
    "    np.random.shuffle(idx)\n",
    "    vectors = vectors[idx]\n",
    "    vectors = sparse.csr_matrix(vectors, dtype=np.float32)\n",
    "    print(type(vectors), vectors.dtype)\n",
    "    return vectors\n",
    "\n",
    "train_vectors = shuffle_and_dtype(train_vectors)\n",
    "val_vectors = shuffle_and_dtype(val_vectors)\n",
    "test_vectors = shuffle_and_dtype(test_vectors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The NTM algorithm, as well as other first-party SageMaker algorithms, accepts data in [RecordIO](https://mxnet.apache.org/api/python/io/io.html#module-mxnet.recordio) [Protobuf](https://developers.google.com/protocol-buffers/) format. Here we define a helper function to convert the data to RecordIO Protobuf format. In addition, we will have the option to split the data into several parts specified by `n_parts`.\n",
    "\n",
    "The algorithm inherently supports multiple files in the training folder (\"channel\"), which could be very helpful for large data sets. In addition, when we use distributed training with multiple workers (compute instances), having multiple files allows us to distribute different portions of the training data to different workers conveniently.\n",
    "\n",
    "Inside this helper function we use the `write_spmatrix_to_sparse_tensor` function provided by the [SageMaker Python SDK](https://github.com/aws/sagemaker-python-sdk) to convert a SciPy sparse matrix into RecordIO Protobuf format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_convert(sparray, prefix, fname_template='data_part{}.pbr', n_parts=2):\n",
    "    import io\n",
    "    import sagemaker.amazon.common as smac\n",
    "\n",
    "    chunk_size = sparray.shape[0] // n_parts\n",
    "    for i in range(n_parts):\n",
    "\n",
    "        # Calculate start and end indices\n",
    "        start = i * chunk_size\n",
    "        end = (i + 1) * chunk_size\n",
    "        if i + 1 == n_parts:\n",
    "            end = sparray.shape[0]\n",
    "\n",
    "        # Convert to record protobuf\n",
    "        buf = io.BytesIO()\n",
    "        smac.write_spmatrix_to_sparse_tensor(array=sparray[start:end], file=buf, labels=None)\n",
    "        buf.seek(0)\n",
    "\n",
    "        fname = os.path.join(prefix, fname_template.format(i))\n",
    "        with open(fname, 'wb') as f:\n",
    "            f.write(buf.getvalue())\n",
    "        print('Saved data to {}'.format(fname))\n",
    "        \n",
    "train_data_dir = os.path.join(data_dir, 'train')\n",
    "val_data_dir = os.path.join(data_dir, 'validation')\n",
    "test_data_dir = os.path.join(data_dir, 'test')\n",
    "\n",
    "check_create_dir(train_data_dir)\n",
    "check_create_dir(val_data_dir)\n",
    "check_create_dir(test_data_dir)\n",
    "\n",
    "split_convert(train_vectors, prefix=train_data_dir, fname_template='train_part{}.pbr', n_parts=4)\n",
    "split_convert(val_vectors, prefix=val_data_dir, fname_template='val_part{}.pbr', n_parts=1)\n",
    "split_convert(test_vectors, prefix=test_data_dir, fname_template='test_part{}.pbr', n_parts=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the vocabulary file\n",
    "To make use of the auxiliary channel for vocabulary file, we first save the text file with the name `vocab.txt` in the auxiliary directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "aux_data_dir = os.path.join(data_dir, 'auxiliary')\n",
    "check_create_dir(aux_data_dir)\n",
    "with open(os.path.join(aux_data_dir, 'vocab.txt'), 'w', encoding='utf-8') as f:\n",
    "    for item in vocab_list:\n",
    "        f.write(item+'\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Store Data on S3\n",
    "\n",
    "Below we upload the data to an Amazon S3 destination for the model to access it during training.\n",
    "\n",
    "#### Setup AWS Credentials\n",
    "\n",
    "We first need to specify data locations and access roles. ***This is the only cell of this notebook that you will need to edit.*** In particular, we need the following data:\n",
    "\n",
    "- The S3 `bucket` and `prefix` that you want to use for training and model data. This should be within the same region as the Notebook Instance, training, and hosting.\n",
    "- The IAM `role` used to give training and hosting access to your data. See the documentation for how to create these. Note: if more than one role is required for notebook instances, training, and/or hosting, please replace the `sagemaker.get_execution_role()` call with the appropriate full IAM role ARN string(s)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sagemaker\n",
    "\n",
    "role = sagemaker.get_execution_role()\n",
    "\n",
    "bucket = sagemaker.Session().default_bucket() #<or insert your own bucket name>#\n",
    "prefix = 'ntm/' + dataset\n",
    "\n",
    "train_prefix = os.path.join(prefix, 'train')\n",
    "val_prefix = os.path.join(prefix, 'val')\n",
    "aux_prefix = os.path.join(prefix, 'auxiliary')\n",
    "test_prefix = os.path.join(prefix, 'test')\n",
    "output_prefix = os.path.join(prefix, 'output')\n",
    "\n",
    "s3_train_data = os.path.join('s3://', bucket, train_prefix)\n",
    "s3_val_data = os.path.join('s3://', bucket, val_prefix)\n",
    "s3_aux_data = os.path.join('s3://', bucket, aux_prefix)\n",
    "s3_test_data = os.path.join('s3://', bucket, test_prefix)\n",
    "output_path = os.path.join('s3://', bucket, output_prefix)\n",
    "print('Training set location', s3_train_data)\n",
    "print('Validation set location', s3_val_data)\n",
    "print('Auxiliary data location', s3_aux_data)\n",
    "print('Test data location', s3_test_data)\n",
    "print('Trained model will be saved at', output_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Upload the input directories to S3\n",
    "We use the `aws` command line interface (CLI) to upload the various input channels. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "cmd_train = 'aws s3 cp ' + train_data_dir + ' ' + s3_train_data + ' --recursive' \n",
    "p=subprocess.Popen(cmd_train, shell=True,stdout=subprocess.PIPE)\n",
    "p.communicate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cmd_val = 'aws s3 cp ' + val_data_dir + ' ' + s3_val_data + ' --recursive' \n",
    "p=subprocess.Popen(cmd_val, shell=True,stdout=subprocess.PIPE)\n",
    "p.communicate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cmd_test = 'aws s3 cp ' + test_data_dir + ' ' + s3_test_data + ' --recursive' \n",
    "p=subprocess.Popen(cmd_test, shell=True,stdout=subprocess.PIPE)\n",
    "p.communicate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cmd_aux = 'aws s3 cp ' + aux_data_dir + ' ' + s3_aux_data + ' --recursive' \n",
    "p=subprocess.Popen(cmd_aux, shell=True,stdout=subprocess.PIPE)\n",
    "p.communicate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Training\n",
    "We have prepared the train, validation, test, and auxiliary input channels on S3. Next, we configure a SageMaker training job to use the NTM algorithm on the data we prepared. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SageMaker hosts the NTM training image as a Docker container in Amazon Elastic Container Registry (ECR), with a separate registry path for each region. The cell below looks up the container for the current region. For the latest Docker container registry paths, please refer to [Amazon SageMaker: Common Parameters](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "container = get_image_uri(boto3.Session().region_name, 'ntm')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above automatically chooses an algorithm container based on the current region. In the call to `sagemaker.estimator.Estimator` below, we also specify the type and count of instances for the training job. Because the wikitext-2 data set is relatively small, we have chosen a CPU-only instance (`ml.c4.xlarge`), but do feel free to change to [other instance types](https://aws.amazon.com/sagemaker/pricing/instance-types/). NTM takes full advantage of GPU hardware and in general trains roughly an order of magnitude faster on a GPU than on a CPU. Multi-GPU or multi-instance training further improves training speed roughly linearly if communication overhead is low compared to compute time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "sess = sagemaker.Session()\n",
    "ntm = sagemaker.estimator.Estimator(container,\n",
    "                                    role, \n",
    "                                    train_instance_count=1, \n",
    "                                    train_instance_type='ml.c4.xlarge',\n",
    "                                    output_path=output_path,\n",
    "                                    sagemaker_session=sess)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can specify the hyperparameters, including the newly introduced `sub_sample`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_topics = 20\n",
    "ntm.set_hyperparameters(num_topics=num_topics, feature_dim=vocab_size, mini_batch_size=60, \n",
    "                        epochs=50, sub_sample=0.7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we need to specify how the data will be distributed to the workers during training as well as their content type. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.session import s3_input\n",
    "s3_train = s3_input(s3_train_data, distribution='ShardedByS3Key',\n",
    "                    content_type='application/x-recordio-protobuf')\n",
    "s3_val = s3_input(s3_val_data, distribution='FullyReplicated',\n",
    "                  content_type='application/x-recordio-protobuf')\n",
    "s3_test = s3_input(s3_test_data, distribution='FullyReplicated',\n",
    "                  content_type='application/x-recordio-protobuf')\n",
    "\n",
    "s3_aux = s3_input(s3_aux_data, distribution='FullyReplicated', content_type='text/plain')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are ready to run the training job. Again, we will notice in the log that the top words are printed together with the WETC and TU scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ntm.fit({'train': s3_train, 'validation': s3_val, 'auxiliary': s3_aux, 'test': s3_test})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the job has completed, you can view its status and other information in the Amazon SageMaker console. Just click on the \"Jobs\" tab and select the training job that matches the training job name printed below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Training job name: {}'.format(ntm.latest_training_job.job_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We demonstrate the utility of the `sub_sample` hyperparameter by setting it to 1.0 and 0.2 for training on the wikitext-103 dataset (simply change the dataset name and download URL, restart the kernel, and re-run this notebook). In both settings, we set `epochs = 100`, and NTM exits training early when the loss on the validation data does not improve for 3 consecutive epochs. We report the TU, WETC, and NPMI of the best epoch based on validation loss, as well as the total time for both settings, below. Note that we ran each training job on a single GPU of a `p2.8xlarge` machine.\n",
    "![subsample_result_table](subsample_table.png)\n",
    "We observe that setting `sub_sample` to 0.2 leads to reduced total training time even though it takes more epochs to converge (49 instead of 18). The increase in the number of epochs to convergence is expected due to the variance introduced by training on a random subset of data per epoch. Yet the overall training time is reduced because training is about 5 times faster per epoch at the subsampling rate of 0.2. We also note the higher scores in terms of TU, WETC, and NPMI at the end of training with subsampling. \n"
   ]
  },
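  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a back-of-the-envelope check on these numbers, assume the time per epoch is proportional to `sub_sample`, in line with the roughly 5x per-epoch speedup reported at 0.2. Measured in units of one full pass over the training data, the two runs then cost approximately\n",
    "\n",
    "$$49 \\times 0.2 = 9.8 \\quad \\text{versus} \\quad 18 \\times 1.0 = 18,$$\n",
    "\n",
    "that is, about $9.8 / 18 \\approx 0.54$ of the work. This is only an estimate under the proportionality assumption; the measured times in the table remain the authoritative numbers."
   ]
  },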
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conclusion\n",
    "\n",
    "In this blog post, we introduced three new features of the SageMaker NTM algorithm. After reading this post and completing the sample notebook, you should know how to add an auxiliary vocabulary channel that automatically maps the integer representations of words in the detected topics to a human-readable vocabulary. You have also learned to evaluate the quality of a model using both the Word Embedding Topic Coherence and Topic Uniqueness metrics. Lastly, you learned to use the subsampling feature to reduce model training time while maintaining similar model performance."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "notice": "Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
